{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "“TextLSTM-Torch.ipynb”的副本",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "m_C0BgmDtVGm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "'''\n",
        "  code by Tae Hwan Jung(Jeff Jung) @graykode, modified by wmathor\n",
        "'''\n",
        "import torch\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.utils.data as Data\n",
        "\n",
        "# Fix the PyTorch RNG so weight initialisation and DataLoader shuffling repeat on every Restart & Run All.\n",
        "SEED = 0\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "dtype = torch.FloatTensor  # legacy tensor-type alias; no later cell in this notebook reads it"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TRKDPS0WuPRI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Vocabulary: the 26 lowercase letters, one class per character.\n",
        "char_arr = list('abcdefghijklmnopqrstuvwxyz')  # ['a', 'b', 'c', ...]\n",
        "idx2word = dict(enumerate(char_arr))  # index -> character\n",
        "word2idx = {ch: idx for idx, ch in idx2word.items()}  # character -> index\n",
        "n_class = len(word2idx)  # number of classes (= vocabulary size)\n",
        "\n",
        "# Training words: every letter but the last is the input, the last letter is the target.\n",
        "seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']\n",
        "\n",
        "# TextLSTM parameters\n",
        "n_hidden = 128\n",
        "n_step = len(seq_data[0]) - 1  # (=3)\n",
        "\n",
        "def make_data(seq_data):\n",
        "    '''Turn each word into a one-hot input sequence and a target class index.\n",
        "\n",
        "    For every word, all characters except the last are one-hot encoded over\n",
        "    `n_class` classes, and the index of the last character becomes the target.\n",
        "\n",
        "    Args:\n",
        "        seq_data: iterable of strings whose characters are all keys of `word2idx`.\n",
        "            The words must share one length so the inputs stack into a single tensor.\n",
        "\n",
        "    Returns:\n",
        "        (input_batch, target_batch): a float32 tensor of shape\n",
        "        [len(seq_data), word_length - 1, n_class] and a LongTensor of shape [len(seq_data)].\n",
        "    '''\n",
        "    one_hot = np.eye(n_class)  # row i is the one-hot vector of class i\n",
        "    input_batch, target_batch = [], []\n",
        "\n",
        "    for seq in seq_data:\n",
        "        char_ids = [word2idx[ch] for ch in seq[:-1]]  # 'm', 'a', 'k' is input\n",
        "        input_batch.append(one_hot[char_ids])\n",
        "        target_batch.append(word2idx[seq[-1]])  # 'e' is target\n",
        "\n",
        "    # Merge into a single ndarray first: building a tensor straight from a list of\n",
        "    # ndarrays is a slow path that recent PyTorch releases warn about.\n",
        "    inputs = torch.from_numpy(np.array(input_batch)).float()\n",
        "    return inputs, torch.LongTensor(target_batch)\n",
        "\n",
        "# Encode the corpus once, then serve it in shuffled mini-batches of three words.\n",
        "input_batch, target_batch = make_data(seq_data)\n",
        "dataset = Data.TensorDataset(input_batch, target_batch)\n",
        "loader = Data.DataLoader(dataset=dataset, batch_size=3, shuffle=True)"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sX4M-kbQZtmS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class TextLSTM(nn.Module):\n",
        "    '''Single-layer LSTM that predicts the next character from a one-hot sequence.\n",
        "\n",
        "    The layer sizes are read from the notebook-level `n_class` (input and output\n",
        "    width) and `n_hidden` (LSTM state width).\n",
        "    '''\n",
        "\n",
        "    def __init__(self):\n",
        "        super(TextLSTM, self).__init__()\n",
        "        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)\n",
        "        # fc: projects the last LSTM output onto the class logits\n",
        "        self.fc = nn.Linear(n_hidden, n_class)\n",
        "\n",
        "    def forward(self, X):\n",
        "        '''Return class logits for the character that follows each input sequence.\n",
        "\n",
        "        Args:\n",
        "            X: tensor of shape [batch_size, n_step, n_class].\n",
        "\n",
        "        Returns:\n",
        "            Tensor of shape [batch_size, n_class] holding unnormalised scores.\n",
        "        '''\n",
        "        batch_size = X.shape[0]\n",
        "        seq_first = X.transpose(0, 1)  # [n_step, batch_size, n_class]\n",
        "\n",
        "        # Allocate the initial states from X so they follow its device and dtype;\n",
        "        # a bare torch.zeros() stays on the CPU and fails once model and input live on a GPU.\n",
        "        # Shape: [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
        "        state_shape = (1, batch_size, self.lstm.hidden_size)\n",
        "        hidden_state = X.new_zeros(state_shape)\n",
        "        cell_state = X.new_zeros(state_shape)\n",
        "\n",
        "        outputs, _ = self.lstm(seq_first, (hidden_state, cell_state))\n",
        "        last_output = outputs[-1]  # output of the final time step: [batch_size, n_hidden]\n",
        "        return self.fc(last_output)\n",
        "\n",
        "\n",
        "# Instantiate the network together with its loss function and optimizer.\n",
        "model = TextLSTM()\n",
        "criterion = nn.CrossEntropyLoss()  # applied to the raw logits returned by the model\n",
        "optimizer = optim.Adam(params=model.parameters(), lr=1e-3)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D9r4DcovvGLR",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 689
        },
        "outputId": "4afd46f9-92bf-42e3-eb04-8142d152c127"
      },
      "source": [
        "# Training: 1000 passes over the shuffled mini-batches.\n",
        "for epoch in range(1000):\n",
        "    for x, y in loader:\n",
        "        optimizer.zero_grad()  # drop gradients left over from the previous step\n",
        "        pred = model(x)\n",
        "        loss = criterion(pred, y)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # On every 100th epoch, report the loss of each mini-batch (computed before the update).\n",
        "        if (epoch + 1) % 100 == 0:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 0100 cost = 0.021760\n",
            "Epoch: 0100 cost = 0.214454\n",
            "Epoch: 0100 cost = 0.016971\n",
            "Epoch: 0100 cost = 0.040900\n",
            "Epoch: 0200 cost = 0.003466\n",
            "Epoch: 0200 cost = 0.041558\n",
            "Epoch: 0200 cost = 0.002600\n",
            "Epoch: 0200 cost = 0.002137\n",
            "Epoch: 0300 cost = 0.007363\n",
            "Epoch: 0300 cost = 0.000711\n",
            "Epoch: 0300 cost = 0.008049\n",
            "Epoch: 0300 cost = 0.000365\n",
            "Epoch: 0400 cost = 0.003190\n",
            "Epoch: 0400 cost = 0.003718\n",
            "Epoch: 0400 cost = 0.000906\n",
            "Epoch: 0400 cost = 0.000536\n",
            "Epoch: 0500 cost = 0.000352\n",
            "Epoch: 0500 cost = 0.001526\n",
            "Epoch: 0500 cost = 0.000329\n",
            "Epoch: 0500 cost = 0.008434\n",
            "Epoch: 0600 cost = 0.001356\n",
            "Epoch: 0600 cost = 0.000167\n",
            "Epoch: 0600 cost = 0.000261\n",
            "Epoch: 0600 cost = 0.003648\n",
            "Epoch: 0700 cost = 0.000073\n",
            "Epoch: 0700 cost = 0.000188\n",
            "Epoch: 0700 cost = 0.000854\n",
            "Epoch: 0700 cost = 0.002573\n",
            "Epoch: 0800 cost = 0.000152\n",
            "Epoch: 0800 cost = 0.000144\n",
            "Epoch: 0800 cost = 0.000526\n",
            "Epoch: 0800 cost = 0.001806\n",
            "Epoch: 0900 cost = 0.000379\n",
            "Epoch: 0900 cost = 0.000104\n",
            "Epoch: 0900 cost = 0.000546\n",
            "Epoch: 0900 cost = 0.000076\n",
            "Epoch: 1000 cost = 0.000081\n",
            "Epoch: 1000 cost = 0.000413\n",
            "Epoch: 1000 cost = 0.000310\n",
            "Epoch: 1000 cost = 0.000003\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6InTlcP8u7Pv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "5775d2c1-af58-4210-a70a-2716ffa3bc6f"
      },
      "source": [
        "# Rebuild each prompt the same way make_data does (everything except the last\n",
        "# character) instead of hard-coding a prefix length of 3.\n",
        "inputs = [sen[:-1] for sen in seq_data]\n",
        "\n",
        "# Inference only: no_grad() skips autograd bookkeeping and replaces the legacy `.data` access.\n",
        "with torch.no_grad():\n",
        "    predict = model(input_batch).argmax(dim=1, keepdim=True)  # [batch_size, 1] class indices\n",
        "\n",
        "# flatten() always yields a 1-D tensor, whereas squeeze() collapses a single-sample\n",
        "# batch to a 0-d tensor that cannot be iterated.\n",
        "print(inputs, '->', [idx2word[n.item()] for n in predict.flatten()])"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['mak', 'nee', 'coa', 'wor', 'lov', 'hat', 'liv', 'hom', 'has', 'sta'] -> ['e', 'd', 'l', 'd', 'e', 'e', 'e', 'e', 'h', 'r']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "96TCSLLSvB_n",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}